#' @export
data_plot.performance_pp_check <- function(x, type = "density", ...) {
  # Converts a posterior predictive check object into plotting data.
  #
  # Returns a data frame with class "data_plot" / "see_performance_pp_check"
  # and an "info" attribute (axis labels, title, range-check flag, bandwidth
  # and model info taken from the attributes of `x`).

  # metadata attached to the result in both code paths below
  make_info <- function() {
    list(
      xlab = attr(x, "response_name"),
      ylab = if (identical(type, "density")) "Density" else "Counts",
      title = "Posterior Predictive Check",
      check_range = attr(x, "check_range"),
      bandwidth = attr(x, "bandwidth"),
      model_info = attr(x, "model_info")
    )
  }

  # for data from "bayesplot::pp_check()", data is already in shape - only
  # tag it with the plotting classes and the metadata
  if (isTRUE(attributes(x)$is_stan) && type != "density") {
    class(x) <- c("data_plot", "see_performance_pp_check", "data.frame")
    attr(x, "info") <- make_info()
    return(x)
  }

  # stack all columns (simulated draws and the observed "y") into long format
  column_names <- colnames(x)
  long_data <- stats::reshape(
    x,
    direction = "long",
    varying = list(column_names),
    v.names = "values",
    timevar = "key",
    times = column_names
  )

  if (is.factor(long_data[["values"]])) {
    long_data[["values"]] <- as.character(long_data[["values"]])
  }

  # drop the trailing "id" column that reshape() appends
  long_data <- long_data[, -ncol(long_data), drop = FALSE]

  # the column "y" holds the observed response, all others are simulations
  is_observed <- long_data$key == "y"
  long_data$key <- ifelse(is_observed, "Observed data", "Model-predicted data")

  # one group per original column, in column order
  long_data$grp <- rep(seq_len(ncol(x)), each = nrow(x))

  attr(long_data, "info") <- make_info()
  class(long_data) <- unique(c(
    "data_plot",
    "see_performance_pp_check",
    class(long_data)
  ))
  long_data
}


# Plot --------------------------------------------------------------------

#' Plot method for posterior predictive checks
#'
#' The `plot()` method for the `performance::check_predictions()` function.
#'
#' @param alpha_line Numeric value specifying the alpha (transparency) of the
#' lines (or points) for the model-predicted data (`yrep`).
#' @param style A ggplot2 theme function.
#' @param type Plot type for the posterior predictive checks plot. Can be
#' `"density"` (default), `"discrete_dots"`, `"discrete_interval"` or
#' `"discrete_both"`. The `discrete_*` options are appropriate for models with
#' discrete outcomes (e.g. binary, integer/count or ordinal responses). If
#' `type` is not specified, a valid plot type stored in `x` (if any) is used.
#' @param x_limits Numeric vector of length 2 specifying the limits of the x-axis.
#' If not `NULL`, will zoom in the x-axis to the specified limits.
#' @inheritParams data_plot
#' @inheritParams plot.see_check_normality
#' @inheritParams plot.see_parameters_distribution
#'
#' @return A ggplot2-object.
#'
#' @seealso See also the vignette about [`check_model()`](https://easystats.github.io/performance/articles/check_model.html).
#'
#' @examples
#' library(performance)
#'
#' model <- lm(Sepal.Length ~ Species * Petal.Width + Petal.Length, data = iris)
#' check_predictions(model)
#'
#' # dot-plot style for count-models
#' d <- iris
#' d$poisson_var <- rpois(150, 1)
#' model <- glm(
#'   poisson_var ~ Species + Petal.Length + Petal.Width,
#'   data = d,
#'   family = poisson()
#' )
#' out <- check_predictions(model)
#' plot(out, type = "discrete_dots")
#' @export
print.see_performance_pp_check <- function(
  x,
  linewidth = 0.5,
  size_point = 2,
  size_bar = 0.7,
  size_axis_title = base_size,
  size_title = 12,
  base_size = 10,
  alpha_line = 0.15,
  style = theme_lucid,
  colors = unname(social_colors(c("green", "blue"))),
  type = "density",
  x_limits = NULL,
  ...
) {
  # Draws the posterior predictive check plot (stacked above the range check
  # when the object requests one) and returns `x` invisibly.
  valid_types <- c(
    "density",
    "discrete_dots",
    "discrete_interval",
    "discrete_both"
  )
  input <- x
  obj_attributes <- attributes(x)
  stored_type <- obj_attributes$type
  with_range <- isTRUE(obj_attributes$check_range)

  # an explicitly supplied `type` always wins; otherwise reuse a valid type
  # stored on the object
  use_stored_type <- missing(type) &&
    !is.null(stored_type) &&
    stored_type %in% valid_types
  if (use_stored_type) {
    type <- stored_type
  } else {
    type <- insight::validate_argument(type, valid_types)
  }

  plot_data <- if (inherits(x, "data_plot")) x else data_plot(x, type)

  pp_plot <- .plot_pp_check(
    plot_data,
    type = type,
    theme_style = style,
    colors = colors,
    linewidth = linewidth,
    size_point = size_point,
    alpha_line = alpha_line,
    base_size = base_size,
    size_axis_title = size_axis_title,
    size_title = size_title,
    x_limits = x_limits,
    is_stan = obj_attributes$is_stan,
    ...
  )

  if (!with_range) {
    suppressWarnings(graphics::plot(pp_plot))
  } else {
    range_plot <- .plot_pp_check_range(input, size_bar, colors = colors)
    graphics::plot(plots(pp_plot, range_plot, n_columns = 1))
  }

  invisible(input)
}


#' @rdname print.see_performance_pp_check
#' @export
plot.see_performance_pp_check <- function(
  x,
  linewidth = 0.5,
  size_point = 2,
  size_bar = 0.7,
  size_axis_title = base_size,
  size_title = 12,
  base_size = 10,
  alpha_line = 0.15,
  style = theme_lucid,
  colors = unname(social_colors(c("green", "blue"))),
  type = "density",
  x_limits = NULL,
  ...
) {
  # Builds (but does not draw) the posterior predictive check plot. When the
  # object requests a range check, the range plot is combined with it via
  # `plots()`.
  valid_types <- c(
    "density",
    "discrete_dots",
    "discrete_interval",
    "discrete_both"
  )
  input <- x
  obj_attributes <- attributes(x)
  stored_type <- obj_attributes$type
  with_range <- isTRUE(obj_attributes$check_range)

  # an explicitly supplied `type` always wins; otherwise reuse a valid type
  # stored on the object
  use_stored_type <- missing(type) &&
    !is.null(stored_type) &&
    stored_type %in% valid_types
  if (use_stored_type) {
    type <- stored_type
  } else {
    type <- insight::validate_argument(type, valid_types)
  }

  plot_data <- if (inherits(x, "data_plot")) x else data_plot(x, type)

  pp_plot <- .plot_pp_check(
    plot_data,
    type = type,
    theme_style = style,
    colors = colors,
    linewidth = linewidth,
    size_point = size_point,
    alpha_line = alpha_line,
    base_size = base_size,
    size_axis_title = size_axis_title,
    size_title = size_title,
    x_limits = x_limits,
    is_stan = obj_attributes$is_stan,
    ...
  )

  if (!with_range) {
    return(pp_plot)
  }

  range_plot <- .plot_pp_check_range(input, size_bar, colors = colors)
  plots(pp_plot, range_plot)
}


.plot_pp_check <- function(
  x,
  linewidth,
  size_point,
  alpha_line,
  theme_style,
  base_size = 10,
  size_axis_title = 10,
  size_title = 12,
  colors,
  type = "density",
  x_limits = NULL,
  is_stan = NULL,
  ...
) {
  # Builds the main posterior predictive check plot.
  #
  # `x` is the output of `data_plot()` (carrying an "info" attribute).
  # Depending on `is_stan`, `type` and the response type recorded in
  # `info$model_info`, this dispatches to the Stan interval plot, the discrete
  # dot/interval plot or the density plot. Afterwards it optionally applies
  # `theme_style` (when `...` contains `check_model = TRUE`), moves the legend
  # to the bottom (when `adjust_legend = TRUE` or a range check is shown) and
  # zooms the x-axis to `x_limits`. Returns a ggplot object.
  info <- attr(x, "info")

  # discrete plot type from "bayesplot::pp_check()" returns a different data
  # structure, so we need to handle it differently
  if (isTRUE(is_stan) && type != "density") {
    return(.plot_check_predictions_stan_dots(
      x,
      colors,
      info,
      linewidth,
      size_point,
      alpha_line,
      ...
    ))
  }

  # default bandwidth, for smoothing
  bandwidth <- info$bandwidth
  if (is.null(bandwidth)) {
    bandwidth <- "nrd"
  }

  # discrete responses are better shown as dots. `isTRUE()` guards against a
  # missing or incomplete `model_info`, where `NULL || ...` would be an error
  minfo <- info$model_info
  suggest_dots <- isTRUE(minfo$is_bernoulli) ||
    isTRUE(minfo$is_count) ||
    isTRUE(minfo$is_ordinal) ||
    isTRUE(minfo$is_categorical)

  if (
    !is.null(type) &&
      type %in% c("discrete_dots", "discrete_interval", "discrete_both") &&
      suggest_dots
  ) {
    out <- .plot_check_predictions_dots(
      x,
      colors,
      info,
      linewidth,
      size_point,
      alpha_line,
      type,
      ...
    )
  } else {
    if (suggest_dots) {
      insight::format_alert(
        "The model has an integer or a discrete response variable.",
        "It is recommended to switch to a dot-plot style, e.g. `plot(check_model(model), type = \"discrete_dots\")`."
      )
    }
    # density plot - for models that have no binary or count/ordinal outcome
    # (or when a density plot was explicitly requested)
    out <- .plot_check_predictions_density(
      x,
      colors,
      info,
      linewidth,
      alpha_line,
      bandwidth,
      ...
    )
  }

  dots <- list(...)
  # when embedded in `check_model()`, apply the shared theme settings
  if (isTRUE(dots[["check_model"]])) {
    out <- out +
      theme_style(
        base_size = base_size,
        plot.title.space = 3,
        axis.title.space = 5,
        axis.title.size = size_axis_title,
        plot.title.size = size_title
      )
  }

  # legend goes to the bottom when requested or when the range plot is shown
  if (isTRUE(dots[["adjust_legend"]]) || isTRUE(info$check_range)) {
    out <- out +
      ggplot2::theme(
        legend.position = "bottom",
        legend.margin = ggplot2::margin(0, 0, 0, 0),
        legend.box.margin = ggplot2::margin(-5, -5, -5, -5)
      )
  }

  # zoom only (coord_cartesian), so no data is dropped from the densities
  if (!is.null(x_limits)) {
    out <- out + ggplot2::coord_cartesian(xlim = x_limits)
  }

  out
}


.plot_check_predictions_density <- function(
  x,
  colors,
  info,
  linewidth,
  alpha_line,
  bandwidth,
  ...
) {
  # Density lines for the observed response and every simulated draw.
  #
  # `x` is the long data from `data_plot()` (columns `values`, `grp`, `key`);
  # one line is drawn per `grp`. Observed data is drawn opaque and 1.7 times
  # wider than the model-predicted lines, which use `alpha_line`.
  # `bandwidth` is forwarded to `stat_density(bw = )`.
  line_colors <- c(
    "Observed data" = colors[1],
    "Model-predicted data" = colors[2]
  )
  line_widths <- c(
    "Observed data" = 1.7 * linewidth,
    "Model-predicted data" = linewidth
  )
  line_alphas <- c(
    "Observed data" = 1,
    "Model-predicted data" = alpha_line
  )

  density_mapping <- ggplot2::aes(
    x = .data$values,
    group = .data$grp,
    color = .data$key,
    linewidth = .data$key,
    alpha = .data$key
  )
  density_layer <- ggplot2::stat_density(
    mapping = density_mapping,
    geom = "line",
    position = "identity",
    bw = bandwidth
  )

  ggplot2::ggplot(x) +
    density_layer +
    ggplot2::scale_y_continuous() +
    ggplot2::scale_color_manual(values = line_colors) +
    ggplot2::scale_linewidth_manual(values = line_widths, guide = "none") +
    ggplot2::scale_alpha_manual(values = line_alphas, guide = "none") +
    ggplot2::labs(
      x = info$xlab,
      y = info$ylab,
      color = "",
      size = "",
      alpha = "",
      title = "Posterior Predictive Check",
      subtitle = "Model-predicted lines should resemble observed data line"
    ) +
    ggplot2::guides(
      color = ggplot2::guide_legend(reverse = TRUE),
      size = ggplot2::guide_legend(reverse = TRUE)
    )
}


.plot_check_predictions_dots <- function(
  x,
  colors,
  info,
  linewidth,
  size_point,
  alpha_line,
  type = "discrete_dots",
  ...
) {
  # Dot / interval plot for discrete responses.
  #
  # `x` is the long data from `data_plot()` (columns `key`, `values`, `grp`;
  # one `grp` per simulated draw plus one for the observed response). For each
  # group the frequency of every response value is counted. The counts are
  # drawn as jittered dots ("discrete_dots"), as a median + interval per
  # response value computed across the simulated draws ("discrete_interval"),
  # or both ("discrete_both"). Returns a ggplot object.

  # remember which group holds the observed response before the keys are lost
  # in the aggregation below; fall back to the last group (where the observed
  # column "y" ends up when it is the last column of the input)
  observed_grp <- unique(x$grp[x$key == "Observed data"])
  if (length(observed_grp) == 0) {
    observed_grp <- max(x$grp)
  }

  # make sure we have a factor, so "table()" generates frequencies for all levels
  # for each group - we need tables of same size to bind data frames
  x$values <- as.factor(x$values)
  x <- stats::aggregate(x["values"], list(grp = x$grp), table)
  x <- cbind(
    data.frame(key = "Model-predicted data", stringsAsFactors = FALSE),
    x
  )
  x <- cbind(x[1:2], as.data.frame(x[[3]]))
  x$key[x$grp %in% observed_grp] <- "Observed data"
  x <- datawizard::data_to_long(
    x,
    select = -1:-2,
    names_to = "x",
    values_to = "count"
  )
  # many distinct response values: use a numeric x-axis
  if (insight::n_unique(x$x) > 8) {
    x$x <- datawizard::to_numeric(x$x, dummy_factors = TRUE)
  }

  p1 <- p2 <- NULL

  if (!is.null(type) && type %in% c("discrete_interval", "discrete_both")) {
    centrality_dispersion <- function(i) {
      c(
        count = stats::median(i, na.rm = TRUE),
        unlist(bayestestR::ci(i)[c("CI_low", "CI_high")])
      )
    }
    # summarise the simulated draws only - the observed counts must not
    # contribute to the model-predicted median and interval
    predicted <- x[x$key == "Model-predicted data", ]
    x_errorbars <- stats::aggregate(
      predicted["count"],
      list(predicted$x),
      centrality_dispersion
    )
    x_errorbars <- cbind(x_errorbars[1], as.data.frame(x_errorbars[[2]]))
    colnames(x_errorbars) <- c("x", "count", "CI_low", "CI_high")
    x_errorbars <- cbind(
      data.frame(key = "Model-predicted data", stringsAsFactors = FALSE),
      x_errorbars
    )

    # observed counts, padded with empty interval columns so they can be
    # row-bound with the summaries
    x_tmp <- x[x$key == "Observed data", ]
    x_tmp$CI_low <- NA
    x_tmp$CI_high <- NA
    x_tmp$grp <- NULL

    x_errorbars <- rbind(x_errorbars, x_tmp)
    p1 <- ggplot2::ggplot() +
      ggplot2::geom_pointrange(
        data = x_errorbars[x_errorbars$key == "Model-predicted data", ],
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          ymin = .data$CI_low,
          ymax = .data$CI_high,
          color = .data$key
        ),
        position = ggplot2::position_nudge(x = 0.2),
        size = 0.4 * size_point,
        linewidth = linewidth,
        stroke = 0,
        shape = 16
      ) +
      ggplot2::geom_point(
        data = x_errorbars[x_errorbars$key == "Observed data", ],
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          color = .data$key
        ),
        size = 1.5 * size_point,
        stroke = 0,
        shape = 16
      )
  }

  if (!is.null(type) && type %in% c("discrete_dots", "discrete_both")) {
    if (is.null(p1)) {
      p2 <- ggplot2::ggplot()
    } else {
      p2 <- p1
    }
    p2 <- p2 +
      ggplot2::geom_point(
        data = x[x$key == "Model-predicted data", ],
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          group = .data$grp,
          color = .data$key
        ),
        alpha = alpha_line,
        position = ggplot2::position_jitter(width = 0.1, height = 0.02),
        size = 0.8 * size_point,
        stroke = 0,
        shape = 16
      ) +
      # for legend
      ggplot2::geom_point(
        data = x[x$key == "Observed data", ],
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          group = .data$grp,
          color = .data$key
        ),
        size = 0.8 * size_point
      ) +
      # observed counts as filled points with a white outline on top
      ggplot2::geom_point(
        data = x[x$key == "Observed data", ],
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count
        ),
        size = size_point,
        shape = 21,
        colour = "white",
        fill = colors[1]
      )
  }

  if (is.null(p2)) {
    p <- p1
  } else {
    p <- p2
  }

  if (identical(type, "discrete_interval")) {
    subtitle <- "Model-predicted intervals should include observed data points"
  } else {
    subtitle <- "Model-predicted points should be close to observed data points"
  }

  p +
    ggplot2::scale_y_continuous() +
    ggplot2::scale_color_manual(
      values = c(
        "Observed data" = colors[1],
        "Model-predicted data" = colors[2]
      )
    ) +
    ggplot2::labs(
      x = info$xlab,
      y = info$ylab,
      color = "",
      size = "",
      alpha = "",
      title = "Posterior Predictive Check",
      subtitle = subtitle
    ) +
    ggplot2::guides(
      color = ggplot2::guide_legend(reverse = TRUE),
      size = ggplot2::guide_legend(reverse = TRUE)
    )
}


.plot_check_predictions_stan_dots <- function(
  x,
  colors,
  info,
  linewidth,
  size_point,
  alpha_line,
  ...
) {
  # Interval plot for discrete outcomes of Stan models.
  #
  # Uses the columns `x`, `Count`, `CI_low`, `CI_high` and `Group` of `x`;
  # rows with `Group == "y"` are the observed data and rows with
  # `Group == "Mean"` the model predictions (drawn as point + interval).

  # relabel the groups so they match the legend of the other check plots
  original_groups <- x$Group
  x$Group[original_groups == "y"] <- "Observed data"
  x$Group[original_groups == "Mean"] <- "Model-predicted data"

  # sanity check, remove NA rows
  x <- x[!is.na(x$Count), ]

  predicted_rows <- x[x$Group == "Model-predicted data", ]
  observed_rows <- x[x$Group == "Observed data", ]
  group_colors <- c(
    "Observed data" = colors[1],
    "Model-predicted data" = colors[2]
  )

  interval_layer <- ggplot2::geom_pointrange(
    data = predicted_rows,
    mapping = ggplot2::aes(
      x = .data$x,
      y = .data$Count,
      ymin = .data$CI_low,
      ymax = .data$CI_high,
      color = .data$Group
    ),
    position = ggplot2::position_nudge(x = 0.2),
    size = 0.4 * size_point,
    linewidth = linewidth,
    stroke = 0,
    shape = 16
  )
  observed_layer <- ggplot2::geom_point(
    data = observed_rows,
    mapping = ggplot2::aes(
      x = .data$x,
      y = .data$Count,
      color = .data$Group
    ),
    size = 1.5 * size_point,
    stroke = 0,
    shape = 16
  )

  ggplot2::ggplot() +
    interval_layer +
    observed_layer +
    ggplot2::scale_y_continuous() +
    ggplot2::scale_color_manual(values = group_colors) +
    ggplot2::labs(
      x = info$xlab,
      y = info$ylab,
      color = "",
      size = "",
      alpha = "",
      title = "Posterior Predictive Check",
      subtitle = "Model-predicted intervals should include observed data points"
    ) +
    ggplot2::guides(
      color = ggplot2::guide_legend(reverse = TRUE),
      size = ggplot2::guide_legend(reverse = TRUE)
    )
}


.plot_pp_check_range <- function(
  x,
  size_bar = 0.7,
  colors = unname(social_colors(c("green", "blue")))
) {
  # Facetted distribution of the minimum and maximum of every simulated draw
  # (all columns of `x` except "y"), with the observed minimum and maximum of
  # `x$y` marked as vertical lines.
  extrema_levels <- c("Minimum", "Maximum")
  simulated <- x[which(names(x) != "y")]

  original <- data.frame(
    x = c(min(x$y), max(x$y)),
    group = factor(extrema_levels, levels = extrema_levels),
    color = "Observed data",
    stringsAsFactors = FALSE
  )

  # one row per simulated draw, holding that draw's extremum
  draw_extrema <- function(fun, label) {
    data.frame(
      x = vapply(simulated, fun, numeric(1)),
      group = label,
      color = "Model-predicted data",
      stringsAsFactors = FALSE
    )
  }
  replicated <- rbind(
    draw_extrema(min, "Minimum"),
    draw_extrema(max, "Maximum")
  )
  replicated$group <- factor(replicated$group, levels = extrema_levels)

  p <- ggplot2::ggplot(
    replicated,
    ggplot2::aes(x = .data$x, group = .data$group)
  ) +
    ggplot2::facet_wrap(~group, scales = "free_x")

  n_distinct <- insight::n_unique(replicated$x)
  if (n_distinct > 12 && !.is_integer(replicated$x)) {
    # many distinct non-integer extrema: bin them
    p <- p +
      ggplot2::geom_histogram(binwidth = size_bar, fill = colors[2], color = NA)
  } else {
    p <- p + ggplot2::geom_bar(width = size_bar, fill = colors[2], color = NA)
    if (n_distinct > 12) {
      # many distinct integer extrema: thin out the axis breaks
      p <- p +
        ggplot2::scale_x_continuous(n.breaks = round(n_distinct / 4))
    }
  }

  observed_lines <- ggplot2::geom_vline(
    data = original,
    mapping = ggplot2::aes(xintercept = .data$x),
    color = colors[1],
    linewidth = 1
  )

  p +
    observed_lines +
    ggplot2::labs(
      x = NULL,
      y = NULL,
      subtitle = "Model-predicted extrema should contain observed data extrema"
    )
}
